{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train MobileNet image classifier\n",
    "Currently uses MobileNetV2 to train an image classifier.\n",
    "\n",
    "Future work:\n",
    "- [ ] support different deep learning neural network architectures (like VGG, ResNet, ...)\n",
    "- [ ] expose hyperparameters of those architectures as input to this component\n",
    "- [ ] support hyperparameter tuning (ideally using Katib)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "#os.environ['create_image']='True'\n",
    "# Only provide defaults: values already supplied through the environment\n",
    "# (e.g. by a pipeline runner) must not be overwritten here.\n",
    "os.environ.setdefault('repository', 'romeokienzler')\n",
    "os.environ.setdefault('version', '0.1')\n",
    "#\n",
    "#os.environ['install_requirements']='True'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "papermill": {
     "duration": 2.413514,
     "end_time": "2021-01-28T15:59:34.938111",
     "exception": false,
     "start_time": "2021-01-28T15:59:32.524597",
     "status": "completed"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sending build context to Docker daemon  72.19kB\n",
      "Step 1/4 : FROM registry.access.redhat.com/ubi8/python-38\n",
      " ---> fd4f06020ce7\n",
      "Step 2/4 : RUN pip install nvflare==2.0.16\n",
      " ---> Using cache\n",
      " ---> 1b4a2b31335b\n",
      "Step 3/4 : RUN pip install tensorflow==2.9.1\n",
      " ---> Using cache\n",
      " ---> 0562edb5a8ed\n",
      "Step 4/4 : ADD nvflare.ipynb /\n",
      " ---> Using cache\n",
      " ---> d5eddadd6193\n",
      "Successfully built d5eddadd6193\n",
      "Successfully tagged claimed-train-nvflare:0.1\n",
      "The push refers to repository [docker.io/romeokienzler/claimed-train-nvflare]\n",
      "\n",
      "\u001b[1B274dd023: Preparing \n",
      "\u001b[1Bda8b6e89: Preparing \n",
      "\u001b[1B2b13e91b: Preparing \n",
      "\u001b[1B71335733: Preparing \n",
      "\u001b[1Bcbc5f0d9: Preparing \n",
      "\u001b[1B3ed1330a: Waiting g \n",
      "\u001b[1B11fd02f0: Preparing \n",
      "\u001b[8B274dd023: Pushed lready exists 1kB\u001b[8A\u001b[2K\u001b[5A\u001b[2K\u001b[6A\u001b[2K\u001b[2A\u001b[2K\u001b[1A\u001b[2K\u001b[3A\u001b[2K\u001b[8A\u001b[2K0.1: digest: sha256:a5017c4a58565bfe45e49810bdeac45d72f61b78c3c2e34e5c82a5cea3506f86 size: 2007\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "\n",
    "def _env_flag(name):\n",
    "    \"\"\"Return True if environment variable `name` holds a truthy value.\n",
    "\n",
    "    bool() of any non-empty string is True, so e.g. create_image='False'\n",
    "    would still trigger an image build; parse common spellings instead.\n",
    "    \"\"\"\n",
    "    return os.environ.get(name, '').strip().lower() in ('1', 'true', 'yes', 'y', 'on')\n",
    "\n",
    "\n",
    "if _env_flag('create_image'):\n",
    "    docker_file=\"\"\"\n",
    "    FROM registry.access.redhat.com/ubi8/python-38\n",
    "    RUN pip install nvflare==2.0.16\n",
    "    RUN pip install tensorflow==2.9.1\n",
    "    ADD nvflare.ipynb /\n",
    "    #ENTRYPOINT [\"ipython\",\"/nvflare.ipynb\"]\n",
    "    \"\"\"\n",
    "    with open(\"Dockerfile\", \"w\") as text_file: \n",
    "        text_file.write(docker_file)\n",
    "\n",
    "    # Build the versioned image, tag it for the configured repository and push it.\n",
    "    !docker build -t claimed-train-nvflare:`echo $version` .\n",
    "    !docker tag claimed-train-nvflare:`echo $version` `echo $repository`/claimed-train-nvflare:`echo $version`\n",
    "    !docker push `echo $repository`/claimed-train-nvflare:`echo $version`\n",
    "elif _env_flag('install_requirements'):\n",
    "    # %pip installs into the environment of the running kernel (unlike !pip).\n",
    "    %pip install nvflare==2.0.16 tensorflow==2.9.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Tag the versioned image as `latest` in the target repository (only the\n",
    "# `:$version` tag is ever built). Guarded by the create_image switch so a\n",
    "# plain Restart & Run All without Docker does not fail.\n",
    "if os.environ.get('create_image', '').strip().lower() in ('1', 'true', 'yes', 'y', 'on'):\n",
    "    !docker tag claimed-train-nvflare:`echo $version` `echo $repository`/claimed-train-nvflare:latest\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "# from tensorflow.keras.applications import ResNet50V2 # , MobileNetV3Small\n",
    "from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout\n",
    "from tensorflow.keras import Model\n",
    "from claimed_utils import unzip, zipdir\n",
    "import os.path\n",
    "import glob\n",
    "from io import BytesIO\n",
    "from minio import Minio\n",
    "from ansible.module_utils.parsing.convert_bool import boolean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @dependency codait_utils.ipynb\n",
    "# @param image_shape comma-separated height,width the images shall be scaled to\n",
    "# and the model then will accept (e.g. 400,400)\n",
    "# @param model_zip model zip file name\n",
    "# @param data_zip data zip file name\n",
    "# @param model model folder name\n",
    "# @param data data folder name\n",
    "# @param epochs number of epochs to train\n",
    "# @param checkpoint activate checkpointing\n",
    "# @param checkpoint_ip minio endpoint\n",
    "# @param checkpoint_user minio user\n",
    "# @param checkpoint_pass minio password\n",
    "# @param checkpoint_bucket minio bucket\n",
    "# @returns model zip file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Artifact names (zip archives and folders)\n",
    "model_zip = os.getenv('model_zip', 'model.zip')\n",
    "data_zip = os.getenv('data_zip', 'data.zip')\n",
    "model_folder = os.getenv('model', 'model')\n",
    "data = os.getenv('data', 'data')\n",
    "\n",
    "# Training settings; image_shape is a 'height,width' string\n",
    "image_shape = os.getenv('image_shape', '400,400')\n",
    "epochs = int(os.getenv('epochs', 1))\n",
    "\n",
    "# MinIO checkpoint settings\n",
    "# NOTE(review): default credentials are hard-coded; prefer supplying them via the environment.\n",
    "checkpoint = boolean(os.getenv('checkpoint', False))\n",
    "checkpoint_ip = os.getenv('checkpoint_ip')\n",
    "checkpoint_user = os.getenv('checkpoint_user', 'minio')\n",
    "checkpoint_pass = os.getenv('checkpoint_pass', 'minio123')\n",
    "checkpoint_bucket = os.getenv('checkpoint_bucket', 'checkpoint')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exists = False\n",
    "\n",
    "if checkpoint:\n",
    "    # Fail early with a clear message instead of inside the Minio client.\n",
    "    if not checkpoint_ip:\n",
    "        raise ValueError('checkpoint is enabled but checkpoint_ip is not set')\n",
    "    client = Minio(checkpoint_ip,\n",
    "                   checkpoint_user,\n",
    "                   checkpoint_pass,\n",
    "                   secure=False)\n",
    "\n",
    "    # If a model zip with this name was already uploaded, download it so\n",
    "    # training can be skipped. Listing with prefix= avoids scanning the\n",
    "    # whole bucket; the exact-name check still guards against longer keys.\n",
    "    for obj in client.list_objects(checkpoint_bucket, prefix=model_zip):\n",
    "        if obj.object_name == model_zip:\n",
    "            exists = True\n",
    "            client.fget_object(checkpoint_bucket, model_zip, model_zip)\n",
    "            break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.029222,
     "end_time": "2021-01-28T15:59:37.986639",
     "exception": false,
     "start_time": "2021-01-28T15:59:37.957417",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Unpack the training data only when no cached model was downloaded.\n",
    "# presumably unzip(target_dir, zip_file) from claimed_utils extracts data_zip into '.' - TODO confirm\n",
    "if not exists:\n",
    "    unzip('.', data_zip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.822596,
     "end_time": "2021-01-28T15:59:38.817009",
     "exception": false,
     "start_time": "2021-01-28T15:59:37.994413",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "if not exists:\n",
    "    folder = glob.glob(data + \"/*\")\n",
    "    num_classes = len(folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.127534,
     "end_time": "2021-01-28T15:59:38.952236",
     "exception": false,
     "start_time": "2021-01-28T15:59:38.824702",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "if not exists:\n",
    "    batch_size = 32\n",
    "    # Parse 'height,width' explicitly; exec() on a parameter value would run arbitrary code.\n",
    "    input_shape = tuple(int(dim) for dim in image_shape.split(','))\n",
    "\n",
    "    # Read from the configured data folder. Both subsets use the same seed so\n",
    "    # the training/validation split is consistent and non-overlapping.\n",
    "    train_ds = tf.keras.preprocessing.image_dataset_from_directory(\n",
    "        data,\n",
    "        validation_split=0.2,\n",
    "        subset=\"training\",\n",
    "        seed=123,\n",
    "        image_size=input_shape,\n",
    "        batch_size=batch_size)\n",
    "\n",
    "    val_ds = tf.keras.preprocessing.image_dataset_from_directory(\n",
    "        data,\n",
    "        validation_split=0.2,\n",
    "        subset=\"validation\",\n",
    "        seed=123,\n",
    "        image_size=input_shape,\n",
    "        batch_size=batch_size)\n",
    "\n",
    "    # One-hot encode the integer labels to match categorical_crossentropy.\n",
    "    train_ds = train_ds.map(lambda x, y: (x, tf.one_hot(y, depth=num_classes)))\n",
    "    val_ds = val_ds.map(lambda x, y: (x, tf.one_hot(y, depth=num_classes)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.014987,
     "end_time": "2021-01-28T15:59:38.975558",
     "exception": false,
     "start_time": "2021-01-28T15:59:38.960571",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def my_net(model, num_classes, freeze_layers=10, full_freeze='N'):\n",
    "    \"\"\"Attach a dense classification head to a convolutional base model.\n",
    "\n",
    "    Args:\n",
    "        model: Keras base model (e.g. MobileNetV2 with include_top=False).\n",
    "        num_classes: number of units of the final softmax layer.\n",
    "        freeze_layers: how many leading layers of `model` to freeze.\n",
    "        full_freeze: any value other than 'N' enables freezing.\n",
    "\n",
    "    Returns:\n",
    "        A new Model mapping model.input to softmax class probabilities.\n",
    "    \"\"\"\n",
    "    head = GlobalAveragePooling2D()(model.output)\n",
    "    # Two identical Dense(512, relu) + Dropout(0.5) blocks.\n",
    "    for _ in range(2):\n",
    "        head = Dense(512, activation='relu')(head)\n",
    "        head = Dropout(0.5)(head)\n",
    "    probabilities = Dense(num_classes, activation='softmax')(head)\n",
    "    classifier = Model(model.input, probabilities)\n",
    "    if full_freeze == 'N':\n",
    "        return classifier\n",
    "    for layer in model.layers[:freeze_layers]:\n",
    "        layer.trainable = False\n",
    "    return classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.015812,
     "end_time": "2021-01-28T15:59:38.999942",
     "exception": false,
     "start_time": "2021-01-28T15:59:38.984130",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# model = ResNet50V2(weights='imagenet',include_top=False)\n",
    "# model = my_net(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 5.130591,
     "end_time": "2021-01-28T15:59:44.139230",
     "exception": false,
     "start_time": "2021-01-28T15:59:39.008639",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "if not exists:\n",
    "    # Parse 'height,width' explicitly (no exec() on a parameter value) and\n",
    "    # append the 3 colour channels expected by MobileNetV2.\n",
    "    input_shape = tuple(int(dim) for dim in image_shape.split(',')) + (3,)\n",
    "\n",
    "    # classes/classifier_activation only take effect with include_top=True;\n",
    "    # the classification head is added by my_net instead.\n",
    "    model = tf.keras.applications.MobileNetV2(\n",
    "        input_shape=input_shape, alpha=1.0, include_top=False,\n",
    "        input_tensor=None, pooling=None, classes=num_classes,\n",
    "        classifier_activation='softmax'\n",
    "    )\n",
    "    model = my_net(model, num_classes=num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.031121,
     "end_time": "2021-01-28T15:59:44.197588",
     "exception": false,
     "start_time": "2021-01-28T15:59:44.166467",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# model = tf.keras.applications.VGG16(\n",
    "#     include_top=True, weights=None, input_tensor=None,\n",
    "#     input_shape=(244, 244, 3), pooling=None, classes=2,\n",
    "#     classifier_activation='softmax'\n",
    "# )\n",
    "# model = my_net(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.040951,
     "end_time": "2021-01-28T15:59:44.261989",
     "exception": false,
     "start_time": "2021-01-28T15:59:44.221038",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "if not exists:\n",
    "    # Labels are one-hot encoded, hence categorical (not sparse) cross-entropy.\n",
    "    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 21.224432,
     "end_time": "2021-01-28T16:00:05.511392",
     "exception": false,
     "start_time": "2021-01-28T15:59:44.286960",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "if not exists:\n",
    "    # train_ds/val_ds are already batched by image_dataset_from_directory;\n",
    "    # Keras documents that batch_size must not be passed to fit() for datasets.\n",
    "    model.fit(\n",
    "        train_ds,\n",
    "        epochs=epochs,\n",
    "        validation_data=val_ds\n",
    "    )\n",
    "    model.save(model_folder)\n",
    "    zipdir(model_zip, model_folder)\n",
    "else:\n",
    "    print('Model cached, skipping training')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload the newly trained model so later runs can skip training. The MinIO\n",
    "# client is only created when checkpointing is enabled, so guard on it too.\n",
    "if checkpoint and not exists:\n",
    "    # fput_object streams the file from disk instead of reading it into memory.\n",
    "    result = client.fput_object(checkpoint_bucket, model_zip, model_zip)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 55.042719,
   "end_time": "2021-01-28T16:00:26.871724",
   "environment_variables": {},
   "exception": null,
   "input_path": "/home/jovyan/work/elyra-classification/train-trusted-ai.ipynb",
   "output_path": "/home/jovyan/work/elyra-classification/train-trusted-ai.ipynb",
   "parameters": {},
   "start_time": "2021-01-28T15:59:31.829005",
   "version": "2.2.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
